"""
使用微分的思路目前不可行
"""
from collections import defaultdict

import torch
# Needed by train(): F.mse_loss / F.binary_cross_entropy were used without an
# explicit import (possibly masked by the star import below).
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.losses import get_loss_f
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
from exps.patterns import data_generator
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters; a wandb sweep may override any of these.
hyperparameter_defaults = dict(
    batch_size=128,
    learning_rate=0.001,
    epochs=401,
    dimension=6,      # latent dimensionality
    loss='dis',       # objective: 'betaH', 'btcvae' or 'dis' (see selection below)
    beta=80,          # regularisation weight forwarded to the loss
    img_id=6,         # which pattern file to load: patterns/{img_id}.pat
    random_seed=224
)

# Start the wandb run; ``wandb.config`` merges the defaults above with any
# sweep-provided overrides.
wandb.init(project="exps", config=hyperparameter_defaults, group='ode')

config = wandb.config
# Seed all RNGs for reproducibility (project helper from utils.helpers).
set_seed(config.random_seed)


class ODEVAE(nn.Module):
    def __init__(self, img_size, encoder, decoder, latent_dim, group=1):
        """
        VAE variant whose decoder is paired with a small MLP head
        (``differential``) trained to predict, from a latent perturbation,
        the reconstruction error that the perturbation causes.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
        encoder : callable
            Encoder factory; instantiated as ``encoder(img_size, latent_dim * group)``.
        decoder : callable
            Decoder factory; instantiated as ``decoder(img_size, latent_dim)``.
        latent_dim : int
            Dimensionality of the latent space.
        group : int, optional
            Multiplier applied to the encoder's output width (default 1).
        """
        super(ODEVAE, self).__init__()

        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(
                    img_size))

        self.latent_dim = latent_dim
        self.group = group
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, self.latent_dim * group)
        self.decoder = decoder(img_size, self.latent_dim)
        # MLP mapping a latent-space perturbation to one scalar per sample
        # (the predicted reconstruction loss).
        self.differential = nn.Sequential(nn.Linear(self.latent_dim, 256), nn.ReLU(),
                                          nn.Linear(256, 256), nn.ReLU(),
                                          nn.Linear(256, 1))

        self.reset_parameters()

    def error(self, logvar):
        """
        Sample a zero-mean Gaussian perturbation with std ``exp(0.5 * logvar)``.

        Parameters
        ----------
        logvar : torch.Tensor
            Diagonal log variance. Shape (batch_size, latent_dim).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            # BUG FIX: was ``torch.zeros_like(std)``, which silently disabled
            # sampling and made the training branch identical to eval mode,
            # contradicting this docstring (and the reference disvae code).
            eps = torch.randn_like(std)
            return mean + eps * std
        else:
            # Reconstruction mode: deterministic, return the mean.
            return mean

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(p_reconstruct, reconstruct, diff, (mu, logvar), sample_latents)``
            where ``p_reconstruct`` decodes the perturbed latent,
            ``reconstruct`` decodes the mean, and ``diff`` is the differential
            head's prediction for the (detached) perturbation.
        """
        mu, logvar = self.encoder(x)
        error = self.error(logvar)
        sample_latents = mu + error
        reconstruct = self.decoder(mu)
        p_reconstruct = self.decoder(sample_latents)

        # Detach so gradients from the differential head's loss do not flow
        # back into the encoder through the sampled perturbation.
        diff = self.differential(error.detach())

        return p_reconstruct, reconstruct, diff, (mu, logvar), sample_latents

    def reset_parameters(self):
        """Re-initialise all submodules with the project's weight init."""
        self.apply(weights_init)

    def sample_latent(self, x):
        """
        Return latent distribution and samples.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        return latent_dist, latent_sample


def binary_cross_entropy(input, target):
    """
    Element-wise binary cross-entropy, matching the semantics of
    ``F.binary_cross_entropy(input, target, reduction='none')``.

    Parameters
    ----------
    input : torch.Tensor
        Predicted probabilities in (0, 1).
    target : torch.Tensor
        Target probabilities/labels in [0, 1], same shape as ``input``.

    Returns
    -------
    torch.Tensor
        Per-element BCE (no reduction), same shape as the inputs.

    Notes
    -----
    BUG FIX: the previous body was ``input * target.log() + (1 - input) *
    (1 - target).log()`` — the prediction/target roles were swapped relative
    to the torch convention and the negation was missing, so it returned a
    log-likelihood rather than a loss.
    """
    return -(target * input.log() + (1 - target) * (1 - input).log())


def train(dl, model, loss_f, epochs):
    """
    Train ``model`` and its error-predicting ``differential`` head, logging
    aggregated per-epoch metrics (plus periodic images / point clouds) to wandb.

    Parameters
    ----------
    dl : torch.utils.data.DataLoader
        Yields (image, label) batches; images are reshaped to (N, 1, 64, 64).
    model : ODEVAE
        Model being trained.
    loss_f : callable
        VAE loss from ``get_loss_f``; presumably records sub-losses into the
        passed ``storer`` dict — TODO confirm against disvae.models.losses.
    epochs : int
        Number of epochs.

    Returns
    -------
    dict
        ``storer`` from the final epoch (metric name -> aggregated value).

    Notes
    -----
    Relies on module-level globals: ``config``, ``dim``, ``vae``, ``dataset``,
    ``labels``, and ``F`` (presumably torch.nn.functional; not imported
    explicitly in this file — verify it is provided by the star import).
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    # NOTE(review): anneal_reg is computed but never used below.
    anneal_reg = torch.linspace(0, 1, epochs * len(dl) + 1).tolist()
    for e in trange(epochs):
        storer = defaultdict(list)
        loss_set = []
        # Pass 1: full VAE objective plus an MSE term training the
        # differential head to predict, from the detached latent perturbation,
        # the per-sample reconstruction BCE of the mean reconstruction.
        for img, label in dl:
            batch_size = img.size(0)
            img = img.view(-1, 1, 64, 64)
            p_reconstruct, recon_batch, diff, latent_dist, latent_sample = model(img)

            # Target: per-sample summed BCE between the mean reconstruction
            # and the input (shape (batch_size, 1), matching diff).
            diff_loss = F.mse_loss(diff,
                                   F.binary_cross_entropy(recon_batch,
                                                          img.data,
                                                          reduction='none').view(batch_size, -1).sum(1, True))
            loss = loss_f(img, p_reconstruct, latent_dist, model.training,
                          storer, latent_sample=latent_sample) + diff_loss

            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_set.append(loss.item())
            storer['diff_loss'].append(diff_loss.item())

        # Pass 2: one extra step with the encoder outputs detached and the
        # differential head frozen (no_grad); only the decoder receives
        # gradients, pushing its actual perturbation error toward the head's
        # prediction. The ``break`` below limits this to the FIRST batch only.
        for img, label in dl:
            batch_size = img.size(0)
            img = img.view(-1, 1, 64, 64)
            mu, logvar = model.encoder(img)
            mu, logvar = mu.detach(), logvar.detach()
            error = model.error(logvar)

            p_recon = model.decoder(mu + error)
            recon = model.decoder(mu)
            with torch.no_grad():
                diff = model.differential(error)
            # Uses the module-level binary_cross_entropy helper (not F's).
            diff_loss = F.mse_loss(diff,
                                   binary_cross_entropy(recon, p_recon).view(batch_size, -1).sum(1, True), )
            opt.zero_grad()
            diff_loss.backward()
            opt.step()
            break

        # Collapse each list of per-batch values to its epoch mean.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)
        # Rank latent dimensions by their per-dimension KL (keys like
        # 'kl_loss_<d>' are presumably written by loss_f — TODO confirm).
        # ``index`` is only used in the point-cloud branch below.
        kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_loss_' in k])
        _, index = kl_loss.sort(descending=True)



        # Histograms of the posterior mean/variance from the LAST batch of
        # pass 1 (latent_dist still refers to that batch).
        with torch.no_grad():
            mu, log_var = latent_dist
            var = log_var.exp().cpu()
            mu = mu.cpu()
            for d in range(dim):
                storer[f"mu_{d}"] = wandb.Histogram(mu[:, d])
                storer[f'var_{d}'] = wandb.Histogram(var[:, d])

        # Log sample images at epochs 1, 51, 101, ...
        # NOTE(review): ``img`` was rebound by pass 2's first batch while
        # ``recon_batch`` is from pass 1's LAST batch, so the logged pair is
        # mismatched — likely unintended.
        if e % 50 == 1:
            storer['recon'] = wandb.Image(recon_batch[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

        # Every 50 epochs: 3D point cloud of the full dataset embedded in the
        # three most informative latent dims, colored by the (assumed 40x40)
        # translation labels — TODO confirm grid size.
        if (e + 1) % 50 == 0:
            with torch.no_grad():
                mu, _ = vae.encoder(dataset.cuda())
                points = mu[:, index[:3]].cpu()
                # points = 10*points
                colors = labels.reshape(40 * 40, 2) * torch.tensor([8, 8]) + 7
                point_cloud = torch.cat([points, colors, torch.ones(40 * 40, 1) * 255], 1).numpy()
                # NOTE(review): 'ponit_cloud' is a typo for 'point_cloud'; it
                # is a wandb metric key, so renaming would change dashboards.
                storer['ponit_cloud'] = wandb.Object3D(point_cloud)

        wandb.log(storer, sync=False, step=e)
    return storer


# generate data
img_id = config.img_id
# Load the base pattern tensor for this run (file saved via torch.save).
img = torch.load(f'patterns/{img_id}.pat')

# Build a dataset of translated copies of the pattern plus translation
# labels.  # assumes gen_translation returns (images, labels) — TODO confirm
imgs, labels = data_generator.gen_translation(img, img.shape)

epochs = config.epochs
dim = config.dimension
beta = config.beta

# Flatten to (N, 1, 64, 64) and build a GPU-resident dataloader (the whole
# dataset is moved to CUDA up front, so num_workers stays 0).
dataset = imgs.reshape(-1, 1, 64, 64)
dl = DataLoader(list(zip(dataset.cuda(), labels.cuda())),
                num_workers=0, batch_size=config.batch_size, shuffle=True)

vae = ODEVAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
# vae.load_state_dict(torch.load('tmp_model.pt'))
vae.cuda()
vae.train()
# wandb.watch(vae)
# Build the VAE objective selected by config.loss; all variants share the
# same bookkeeping and annealing settings.
loss_f = None
shared_kwargs = dict(record_loss_every=2, rec_dist='bernoulli',
                     reg_anneal=epochs * len(dl))
if config.loss == 'betaH':
    loss_f = get_loss_f('betaH', betaH_B=beta, **shared_kwargs)
elif config.loss == 'btcvae':
    loss_f = get_loss_f('btcvae', btcvae_A=1, btcvae_B=beta, btcvae_G=1,
                        n_data=len(dataset), **shared_kwargs)
elif config.loss == 'dis':
    loss_f = get_loss_f('Dis', betaH_B=beta, **shared_kwargs)
else:
    raise Exception('unknown loss')

storer = train(dl, vae, loss_f, epochs)

# Re-rank latent dimensions by KL using the final epoch's storer (same
# computation as inside train()); ``index`` drives the plots below.
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_loss_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
torch.save(vae.state_dict(), 'tmp_model.pt')

vae.eval()

vae.cpu()

# NOTE(review): drops the local reference to wandb.config; nothing below
# reads ``config``, but the motivation is unclear — confirm before relying on it.
del config

# traversal
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
# Create a new 3D axes object.
# NOTE(review): direct Axes3D(fig) construction is deprecated in recent
# matplotlib (use fig.add_subplot(projection='3d')) — confirm the pinned
# matplotlib version before changing.
ax = Axes3D(fig)
with torch.no_grad():
    # Sweep the two most informative latent dimensions over [-1, 1]^2 while
    # holding the others at zero, and plot the differential head's scalar
    # output as a surface.
    X, Y = torch.meshgrid(torch.linspace(-1, 1, 20), torch.linspace(-1, 1, 20))
    l = [torch.zeros_like(X) for i in index]
    l[index[0]] = X
    l[index[1]] = Y
    input = torch.stack(l, -1)

    z = vae.differential(input.view(-1, dim)).view(20, 20)
    ax.plot_surface(X.numpy(), Y.numpy(), z.numpy(), rstride=1, cstride=1, cmap="rainbow")

wandb.log({'traversal_diff': wandb.Image(fig)})

# Project the subsampled dataset onto the two most informative latent
# dimensions (plot_projection and plt presumably come from the
# utils.visualize star import — TODO confirm).
fig = plot_projection(imgs[::4, ::4], vae, dim=index[:2])
wandb.log({'projection_xy': wandb.Image(fig)})
# wandb.save('strip.py')
plt.show()
wandb.join()
